"""
Code adapted from https://github.com/if-loops/selective-synaptic-dampening/tree/main/src
https://arxiv.org/abs/2308.07707
"""

from typing import Any, Tuple
from torchvision.datasets import CIFAR100, CIFAR10, ImageFolder
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import MNIST
import os
import glob
from PIL import Image

# Per-channel normalisation statistics passed to transforms.Normalize below
# (values from https://github.com/weiaicunzai/pytorch-cifar100).
CIFAR_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# NOTE: the pipelines below are plain lists rather than transforms.Compose
# objects; the CIFAR dataset classes in this file add a trailing Resize step
# and wrap the result in transforms.Compose themselves.

# Augmenting pipeline for training from scratch: random crop / flip / rotation
# (details see https://github.com/weiaicunzai/pytorch-cifar100).
transform_train_from_scratch = [
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR_MEAN, std=CIFAR_STD),
]

# Pipeline used on training data while unlearning: tensor conversion and
# normalisation only, no random augmentation.
transform_unlearning = [transforms.ToTensor(), transforms.Normalize(mean=CIFAR_MEAN, std=CIFAR_STD)]

# Evaluation pipeline: same steps as the unlearning pipeline.
transform_test = [transforms.ToTensor(), transforms.Normalize(mean=CIFAR_MEAN, std=CIFAR_STD)]

class Cifar100(CIFAR100):
    """CIFAR-100 wrapper whose items are ``(image, empty_tensor, label)`` triples.

    Args:
        root: Dataset root directory, forwarded to ``CIFAR100``.
        train: Forwarded to ``CIFAR100``; also selects the transform pipeline.
        unlearning: Only consulted when ``train`` is true. If true, the
            non-augmenting ``transform_unlearning`` pipeline is used instead
            of ``transform_train_from_scratch``.
        download: Forwarded to ``CIFAR100``.
        img_size: Size given to the ``transforms.Resize`` step that is added
            at the end of the selected pipeline.
    """

    def __init__(self, root, train, unlearning, download, img_size=32):
        if not train:
            base_steps = transform_test
        elif unlearning:
            base_steps = transform_unlearning
        else:
            base_steps = transform_train_from_scratch

        # Build a fresh list instead of appending to the module-level one:
        # appending in place would stack one more Resize onto the shared
        # pipeline every time a dataset is instantiated.
        transform = transforms.Compose([*base_steps, transforms.Resize(img_size)])

        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(x, empty_tensor, y)`` for the sample at ``index``.

        The empty tensor fills the middle slot so the item has three fields.
        """
        x, y = super().__getitem__(index)
        return x, torch.Tensor([]), y


class Cifar20(CIFAR100):
    """CIFAR-100 wrapper that also yields the coarse (20-superclass) label.

    Items are ``(image, fine_label, coarse_label)`` triples.

    Args:
        root: Dataset root directory, forwarded to ``CIFAR100``.
        train: Forwarded to ``CIFAR100``; also selects the transform pipeline.
        unlearning: Only consulted when ``train`` is true. If true, the
            non-augmenting ``transform_unlearning`` pipeline is used instead
            of ``transform_train_from_scratch``.
        download: Forwarded to ``CIFAR100``.
        img_size: Size given to the ``transforms.Resize`` step that is added
            at the end of the selected pipeline.
        indices: Optional sample indices; when given, ``data`` and
            ``targets`` are restricted to these samples (in this order).
    """

    def __init__(self, root, train, unlearning, download, img_size=32, indices=None):
        if not train:
            base_steps = transform_test
        elif unlearning:
            base_steps = transform_unlearning
        else:
            base_steps = transform_train_from_scratch

        # Build a fresh list instead of appending to the module-level one:
        # appending in place would stack one more Resize onto the shared
        # pipeline every time a dataset is instantiated.
        transform = transforms.Compose([*base_steps, transforms.Resize(img_size)])

        super().__init__(root=root, train=train, download=download, transform=transform)

        # This map matches subclasses to their superclass. E.g., rocket (69) to Vehicle2 (19).
        # Taken from https://github.com/vikram2000b/bad-teaching-unlearning
        self.coarse_map = {
            0: [4, 30, 55, 72, 95],
            1: [1, 32, 67, 73, 91],
            2: [54, 62, 70, 82, 92],
            3: [9, 10, 16, 28, 61],
            4: [0, 51, 53, 57, 83],
            5: [22, 39, 40, 86, 87],
            6: [5, 20, 25, 84, 94],
            7: [6, 7, 14, 18, 24],
            8: [3, 42, 43, 88, 97],
            9: [12, 17, 37, 68, 76],
            10: [23, 33, 49, 60, 71],
            11: [15, 19, 21, 31, 38],
            12: [34, 63, 64, 66, 75],
            13: [26, 45, 77, 79, 99],
            14: [2, 11, 35, 46, 98],
            15: [27, 29, 44, 78, 93],
            16: [36, 50, 65, 74, 80],
            17: [47, 52, 56, 59, 96],
            18: [8, 13, 48, 58, 90],
            19: [41, 69, 81, 85, 89],
        }
        # Reverse lookup (fine label -> coarse label), built once so that
        # __getitem__ does not have to scan the whole map for every sample.
        self._fine_to_coarse = {
            fine: coarse
            for coarse, fine_labels in self.coarse_map.items()
            for fine in fine_labels
        }

        # Subsample data and targets if indices are given
        if indices is not None:
            self.data = self.data[indices]
            self.targets = [self.targets[i] for i in indices]

    def __getitem__(self, index):
        """Return ``(x, y, coarse_y)`` for the sample at ``index``.

        Raises:
            ValueError: If the fine label does not appear in ``coarse_map``.
        """
        x, y = super().__getitem__(index)
        coarse_y = self._fine_to_coarse.get(y)
        if coarse_y is None:
            # Explicit raise rather than assert, so the check is not
            # stripped when Python runs with -O.
            raise ValueError(f"No coarse label found for fine label {y}")
        return x, y, coarse_y
    



class Cifar10(CIFAR10):
    """CIFAR-10 wrapper whose items are ``(image, empty_tensor, label)`` triples.

    Args:
        root: Dataset root directory, forwarded to ``CIFAR10``.
        train: Forwarded to ``CIFAR10``; also selects the transform pipeline.
        unlearning: Only consulted when ``train`` is true. If true, the
            non-augmenting ``transform_unlearning`` pipeline is used instead
            of ``transform_train_from_scratch``.
        download: Forwarded to ``CIFAR10``.
        img_size: Size given to the ``transforms.Resize`` step that is added
            at the end of the selected pipeline.
        indices: Optional sample indices; when given, ``data`` and
            ``targets`` are restricted to these samples (in this order).
    """

    def __init__(self, root, train, unlearning, download, img_size=32, indices=None):
        if not train:
            base_steps = transform_test
        elif unlearning:
            base_steps = transform_unlearning
        else:
            base_steps = transform_train_from_scratch

        # Build a fresh list instead of appending to the module-level one:
        # appending in place would stack one more Resize onto the shared
        # pipeline every time a dataset is instantiated.
        transform = transforms.Compose([*base_steps, transforms.Resize(img_size)])

        super().__init__(root=root, train=train, download=download, transform=transform)

        # Subsample data and targets if indices are given
        if indices is not None:
            self.data = self.data[indices]
            self.targets = [self.targets[i] for i in indices]

    def __getitem__(self, index):
        """Return ``(x, empty_tensor, y)`` for the sample at ``index``.

        The empty tensor fills the middle slot so the item has three fields.
        """
        x, y = super().__getitem__(index)
        return x, torch.Tensor([]), y


# Single-channel normalisation statistics passed to transforms.Normalize below.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# NOTE: both pipelines are plain lists rather than transforms.Compose objects;
# the Mnist dataset class in this file adds further steps and wraps the
# result in transforms.Compose itself.

# Training pipeline: small random rotation, then tensor conversion and normalisation.
transform_mnist_train = [
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=MNIST_MEAN, std=MNIST_STD),
]

# Evaluation pipeline: tensor conversion and normalisation only.
transform_mnist_test = [transforms.ToTensor(), transforms.Normalize(mean=MNIST_MEAN, std=MNIST_STD)]

class Mnist(MNIST):
    """MNIST wrapper whose items are ``(image, empty_tensor, label)`` triples.

    Images are converted to 3-channel grayscale before the rest of the
    pipeline runs, and resized at the end of it.

    Args:
        root: Dataset root directory, forwarded to ``MNIST``.
        train: Forwarded to ``MNIST``; also selects the transform pipeline.
        unlearning: Only consulted when ``train`` is true. If true, the
            non-augmenting ``transform_mnist_test`` pipeline is used instead
            of ``transform_mnist_train``.
        download: Forwarded to ``MNIST``.
        img_size: Size given to the ``transforms.Resize`` step that is added
            at the end of the selected pipeline.
        indices: Optional sample indices; when given, ``data`` and
            ``targets`` are restricted to these samples (in this order).
    """

    def __init__(self, root, train, unlearning, download, img_size=32, indices=None):
        if train and not unlearning:
            base_steps = transform_mnist_train
        else:
            base_steps = transform_mnist_test

        # Build a fresh list instead of inserting into / appending to the
        # module-level one: mutating it in place would add another Grayscale
        # and another Resize to the shared pipeline on every instantiation.
        transform = transforms.Compose(
            [
                transforms.Grayscale(num_output_channels=3),
                *base_steps,
                transforms.Resize(img_size),
            ]
        )

        super().__init__(root=root, train=train, download=download, transform=transform)

        # Subsample data and targets if indices are given
        if indices is not None:
            self.data = self.data[indices]
            self.targets = [self.targets[i] for i in indices]

    def __getitem__(self, index):
        """Return ``(x, empty_tensor, y)`` for the sample at ``index``.

        The empty tensor fills the middle slot so the item has three fields.
        """
        x, y = super().__getitem__(index)
        return x, torch.Tensor([]), y
    
class UnLearningData(Dataset):
    """Concatenation of a forget set and a retain set with membership labels.

    Indices ``[0, forget_len)`` address ``forget_data`` and yield label ``1``;
    the remaining indices address ``retain_data`` and yield label ``0``.
    Only the first field of each underlying item is returned.
    """

    def __init__(self, forget_data, retain_data):
        super().__init__()
        self.forget_data, self.retain_data = forget_data, retain_data
        self.forget_len, self.retain_len = len(forget_data), len(retain_data)

    def __len__(self):
        """Return the combined number of forget and retain samples."""
        return self.forget_len + self.retain_len

    def __getitem__(self, index):
        """Return ``(x, flag)`` where ``flag`` is 1 for forget samples, else 0."""
        retain_pos = index - self.forget_len
        if retain_pos < 0:
            # Still inside the forget range.
            return self.forget_data[index][0], 1
        return self.retain_data[retain_pos][0], 0


# MUCAC recommended transforms

# Training pipeline: resize to 128, then random flip / affine / colour jitter.
mucac_train_transform = transforms.Compose(
    [
        transforms.Resize(size=128),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=0, scale=(0.8, 1.2), shear=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: resize to 128 and convert to a tensor, no augmentation.
mucac_test_transform = transforms.Compose(
    [
        transforms.Resize(size=128),
        transforms.ToTensor(),
    ]
)

# Identity and label loading (run once per script)
def load_mucac_meta(path):
    """Load the identity and attribute metadata files found under ``path``.

    Reads ``CelebA-HQ-identity.txt`` (two whitespace-separated fields per
    line) and ``CelebA-HQ-attribute.txt`` (first two lines skipped, then one
    row per file). Blank lines in either file are ignored.

    Args:
        path: Directory containing the two metadata files.

    Returns:
        A tuple ``(identities, label_map)`` where ``identities`` maps the
        first field of each identity line to the second (both kept as
        strings) and ``label_map`` maps each file name to a dict
        ``{"smiling": int}`` read from column 32 of its attribute row.

    Raises:
        ValueError: If a non-blank identity line does not have exactly two
            fields, or an attribute value is not an integer.
        IndexError: If an attribute row has too few columns.
    """
    # Load identities
    identities = {}
    with open(os.path.join(path, "CelebA-HQ-identity.txt")) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing one) instead of crashing.
                continue
            file_name, identity = fields
            identities[file_name] = identity

    # Load attributes
    # Attribute name -> column in the split row (column 0 is the file name).
    attributes_map = {"smiling": 32}
    label_map = {}
    with open(os.path.join(path, "CelebA-HQ-attribute.txt")) as f:
        for line_no, line in enumerate(f):
            if line_no < 2:
                # The first two lines are not data rows.
                continue
            parts = line.split()
            if not parts:
                continue
            label_map[parts[0]] = {attr: int(parts[idx]) for attr, idx in attributes_map.items()}
    return identities, label_map


class MUCAC(Dataset):
    """Smiling/not-smiling image dataset split by identity id.

    Images are read from ``<root>/CelebAMask-HQ/CelebA-HQ-img/*.jpg`` and
    metadata from ``load_mucac_meta(root)``. Each image goes to one split
    according to its integer identity:

    * ``train`` (``train=True, unlearning=False``):
      ``train_index <= identity < forget_index``
    * ``forget`` (``train=True, unlearning=True``):
      ``forget_index <= identity < unseen_index``
    * ``test`` (``train=False``): ``identity < train_index``

    Items are ``(image, empty_tensor, label)`` triples, where ``label`` is a
    ``torch.long`` tensor holding 1 for smiling and 0 otherwise.

    Args:
        root: Directory containing the metadata files and ``CelebAMask-HQ``.
        train: Selects the train/forget splits (true) or the test split.
        unlearning: With ``train`` true, selects the forget split and the
            non-augmenting transform.
        download: Accepted for signature parity with the other datasets in
            this file; not used.
        img_size: Stored on the instance. NOTE(review): the module-level
            MUCAC transforms resize to a fixed 128, so this value does not
            currently affect the output size.
        indices: Optional positions into the selected split; when given,
            only these samples are kept (in this order).
    """

    def __init__(self, root, train, unlearning, download=False, img_size=128, indices=None):  # img size was 128 for resnet
        self.root = root
        self.train = train
        self.unlearning = unlearning
        self.img_size = img_size

        # Determine split type
        if train:
            split = "forget" if unlearning else "train"
        else:
            split = "test"

        # Identity-id boundaries between the splits.
        self.train_index = 190
        self.forget_index = 1970
        self.unseen_index = 4855
        self.img_dir = os.path.join(self.root, "CelebAMask-HQ", "CelebA-HQ-img")

        print(self.img_dir)
        # Load metadata
        self.identities, self.label_map = load_mucac_meta(self.root)

        # Augment only when training from scratch.
        if train and not unlearning:
            self.transform = mucac_train_transform
        else:
            self.transform = mucac_test_transform

        # Select files based on identity range
        self.image_paths = []
        self.labels = []

        # glob returns files in arbitrary order; sort so that sample
        # positions (and therefore `indices`) are reproducible across runs
        # and machines.
        for img_path in sorted(glob.glob(os.path.join(self.img_dir, "*.jpg"))):
            file_name = os.path.basename(img_path)
            identity = int(self.identities[file_name])
            smiling = self.label_map[file_name]["smiling"]
            if smiling == -1:
                # Map the -1 "absent" encoding to class 0.
                smiling = 0

            if split == "train":
                selected = self.train_index <= identity < self.forget_index
            elif split == "forget":
                selected = self.forget_index <= identity < self.unseen_index
            else:
                selected = identity < self.train_index

            if selected:
                self.image_paths.append(img_path)
                self.labels.append(smiling)

        # Subsample paths and labels if indices are given
        if indices is not None:
            self.image_paths = [self.image_paths[i] for i in indices]
            self.labels = [self.labels[i] for i in indices]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, empty_tensor, label)`` for the sample at ``idx``."""
        # Close the underlying file as soon as the pixels are loaded;
        # convert() returns an independent in-memory image.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert("RGB")
        img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, torch.Tensor([]), label
